{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import random\n",
    "import math\n",
    "import csv\n",
    "import datetime\n",
    "from sklearn import svm\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import f1_score\n",
    "import multiprocessing as mp\n",
    "from sklearn import svm\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import average_precision_score\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.preprocessing import normalize\n",
    "import numpy as np\n",
    "from sklearn import linear_model\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.preprocessing import StandardScaler  \n",
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def CommonNeighbors(u, v, g):\n",
    "    u_neighbors = set(g.neighbors(u))\n",
    "    v_neighbors = set(g.neighbors(v))\n",
    "    return len(u_neighbors.intersection(v_neighbors))\n",
    "def common_neighbors(g, edges):\n",
    "    result = []\n",
    "    for edge in edges:\n",
    "        node_one, node_two = edge[0], edge[1]\n",
    "        num_common_neighbors = 0\n",
    "        try:\n",
    "            neighbors_one, neighbors_two = g.neighbors(node_one), g.neighbors(node_two)\n",
    "            for neighbor in neighbors_one:\n",
    "                if neighbor in neighbors_two:\n",
    "                    num_common_neighbors += 1\n",
    "            result.append((node_one, node_two, num_common_neighbors))\n",
    "        except:\n",
    "            pass\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def AdamicAdar(u, v, g):\n",
    "    u_neighbors = set(g.neighbors(u))\n",
    "    v_neighbors = set(g.neighbors(v))\n",
    "    aa = 0\n",
    "    for i in u_neighbors.intersection(v_neighbors):\n",
    "        aa += 1 / math.log(len(g.neighbors(i)))\n",
    "    return aa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ResourceAllocation(u, v, g):\n",
    "    u_neighbors = set(g.neighbors(u))\n",
    "    v_neighbors = set(g.neighbors(v))\n",
    "    ra = 0\n",
    "    for i in u_neighbors.intersection(v_neighbors):\n",
    "        ra += 1 / float(len(g.neighbors(i)))\n",
    "    return ra"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def JaccardCoefficient(u, v, g):\n",
    "    u_neighbors = set(g.neighbors(u))\n",
    "    v_neighbors = set(g.neighbors(v))\n",
    "    return len(u_neighbors.intersection(v_neighbors)) / float(len(u_neighbors.union(v_neighbors)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def PreferentialAttachment(u, v, g):\n",
    "    return len(g.neighbors(u))*len(g.neighbors(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def AllFeatures(u,v,g1, g2):\n",
    "    '''\n",
    "    the change of features in two consecutive sub graphs\n",
    "    '''\n",
    "    try:\n",
    "        cn = CommonNeighbors(u, v, g2)\n",
    "        aa = AdamicAdar(u, v, g2)\n",
    "        ra = ResourceAllocation(u, v, g2)\n",
    "        jc = JaccardCoefficient(u, v, g2)\n",
    "        pa = PreferentialAttachment(u, v, g2)\n",
    "\n",
    "        delta_cn = cn - CommonNeighbors(u, v, g1)\n",
    "        delta_aa = aa - AdamicAdar(u, v, g1)\n",
    "        delta_ra = ra - ResourceAllocation(u, v, g1)\n",
    "        delta_jc = jc - JaccardCoefficient(u, v, g1)\n",
    "        delta_pa = pa - PreferentialAttachment(u, v, g1)\n",
    "        return {\"cn\":cn, \"aa\": aa, \"ra\":ra, \"jc\":jc, \"pa\":pa,\n",
    "            \"delta_cn\": delta_cn, \"delta_aa\": delta_aa, \"delta_ra\": delta_ra,\n",
    "             \"delta_jc\": delta_jc, \"delta_pa\": delta_pa}\n",
    "    except:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_set = [common_neighbors,\n",
    "                   nx.resource_allocation_index,\n",
    "                   nx.jaccard_coefficient,\n",
    "                   nx.adamic_adar_index,\n",
    "                   nx.preferential_attachment\n",
    "                   ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def produce_fake_edge(g, neg_g,num_test_edges):\n",
    "    i = 0\n",
    "    while i < num_test_edges:\n",
    "        edge = random.sample(g.nodes(), 2)\n",
    "        try:\n",
    "            shortest_path = nx.shortest_path_length(g,source=edge[0],target=edge[1])\n",
    "            if shortest_path >= 2:\n",
    "                neg_g.add_edge(edge[0],edge[1], positive=\"False\")\n",
    "                i += 1\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_graph_from_file(filename):\n",
    "    print(\"----------------build graph--------------------\")\n",
    "    f = open(filename, \"rb\")\n",
    "    g = nx.read_edgelist(f)\n",
    "    return g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_extraction(g, pos_num, neg_num, neg_mode, neg_distance=2, delete=1):\n",
    "    \"\"\"\n",
    "\n",
    "    :param g:  the graph\n",
    "    :param pos_num: the number of positive samples\n",
    "    :param neg_num: the number of negative samples\n",
    "    :param neg_distance: the distance between two nodes in negative samples\n",
    "    :param delete: if delete ==0, don't delete positive edges from graph\n",
    "    :return: pos_sample is a list of positive edges, neg_sample is a list of negative edges\n",
    "    \"\"\"\n",
    "\n",
    "    print(\"----------------extract positive samples--------------------\")\n",
    "    # randomly select pos_num as test edges\n",
    "    pos_sample = random.sample(g.edges(), pos_num)\n",
    "    sample_g = nx.Graph()\n",
    "    sample_g.add_edges_from(pos_sample, positive=\"True\")\n",
    "    nx.write_edgelist(sample_g, r\"data\\sample_positive_\" +str(pos_num)+ \".txt\", data=['positive'])\n",
    "\n",
    "    # adding non-existing edges\n",
    "    print(\"----------------extract negative samples--------------------\")\n",
    "    i = 0\n",
    "    neg_g = nx.Graph()\n",
    "    produce_fake_edge(g,neg_g,neg_num)\n",
    "    nx.write_edgelist(neg_g, r\"data\\sample_negative_\" +str(neg_num)+ \".txt\", data=[\"positive\"])\n",
    "    neg_sample = neg_g.edges()\n",
    "    neg_g.add_edges_from(pos_sample,positive=\"True\")\n",
    "    nx.write_edgelist(neg_g, r\"data\\sample_combine_\" +str(pos_num + neg_num)+ \".txt\", data=[\"positive\"])\n",
    "\n",
    "    # remove the positive sample edges, the rest is the training set\n",
    "    if delete == 0:\n",
    "        return pos_sample, neg_sample\n",
    "    else:\n",
    "        g.remove_edges_from(pos_sample)\n",
    "        nx.write_edgelist(g, r\"data\\training.txt\", data=False)\n",
    "\n",
    "        return pos_sample, neg_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_extraction(g, pos_sample, neg_sample, feature_name, model=\"single\", combine_num=5):\n",
    "\n",
    "    data = []\n",
    "    if model == \"single\":\n",
    "        print (\"-----extract feature:\", feature_name.__name__, \"----------\")\n",
    "        preds = feature_name(g, pos_sample)\n",
    "        feature = [feature_name.__name__] + [i[2] for i in preds]\n",
    "        label = [\"label\"] + [\"Pos\" for i in range(len(feature))]\n",
    "        preds = feature_name(g, neg_sample)\n",
    "        feature1 = [i[2] for i in preds]\n",
    "        feature = feature + feature1\n",
    "        label = label + [\"Neg\" for i in range(len(feature1))]\n",
    "        data = [feature, label]\n",
    "        data = transpose(data)\n",
    "        print(\"----------write the feature to file---------------\")\n",
    "        write_data_to_file(data, r\"data\\features_\" + model + \"_\" + feature_name.__name__ + \".csv\")\n",
    "    else:\n",
    "        label = [\"label\"] + [\"1\" for i in range(len(pos_sample))] + [\"0\" for i in range(len(neg_sample))]\n",
    "        for j in feature_name:\n",
    "            print (\"-----extract feature:\", j.__name__, \"----------\")\n",
    "            preds = j(g, pos_sample)\n",
    "\n",
    "            feature = [j.__name__] + [i[2] for i in preds]\n",
    "            preds = j(g, neg_sample)\n",
    "            feature = feature + [i[2] for i in preds]\n",
    "            data.append(feature)\n",
    "\n",
    "        data.append(label)\n",
    "        data = transpose(data)\n",
    "        print(\"----------write the features to file---------------\")\n",
    "        write_data_to_file(data, r\"data\\features_\" + model + \"_\" + str(combine_num) + \".csv\")\n",
    "    return data\n",
    "\n",
    "\n",
    "def write_data_to_file(data, filename):\n",
    "    csvfile = open(filename, \"w\")\n",
    "    writer = csv.writer(csvfile)\n",
    "    for i in data:\n",
    "        writer.writerow(i)\n",
    "    csvfile.close()\n",
    "\n",
    "\n",
    "def transpose(data):\n",
    "    return [list(i) for i in zip(*data)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(filename=r\"data\\facebook_combined.txt\", pos_num=0.1, neg_num=0.1, model=\"combined\", combine_num=1,\n",
    "         feature_name=common_neighbors, neg_mode=\"hard\"):\n",
    "    if combine_num==2:\n",
    "        pos_num=0.08\n",
    "        neg_num=0.08\n",
    "    if combine_num==3:\n",
    "        pos_num=0.1\n",
    "        neg_num=0.1\n",
    "    g = create_graph_from_file(filename)\n",
    "    num_edges = g.number_of_edges()\n",
    "    pos_num = int(num_edges * pos_num)\n",
    "    neg_num = int(num_edges * neg_num)\n",
    "    pos_sample, neg_sample = sample_extraction(g, pos_num, neg_num,neg_mode)\n",
    "    train_data = feature_extraction(g, pos_sample, neg_sample, feature_name, model, combine_num)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "#______________________Entry Point________________________\n",
    "#Fn: Name of data set you want to run this code for, and cn is a integer for that dataset(any integer will work but different for each dataset)\n",
    "#By default it is set to Twitter Data Set\n",
    "#The project was run using Facebook and Twitter dataset but it works with any social network dataset from http://snap.stanford.edu/data/\n",
    "#Following Scoring Methods are used to construct feature Set----\n",
    "#common_neighbors,resource_allocation_index, jaccard_coefficient, adamic_adar_index, preferential_attachment\n",
    "#SVM ANN and Logistic Regresssion is used for classificaion\n",
    "fn=r\"data\\facebook_combined.txt\";\n",
    "cn=3;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------build graph--------------------\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'data\\\\facebook_combined.txt'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-15-be8e567d6262>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;31m#Run this line to genrate feature Set\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"combined\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcombine_num\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeature_name\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeature_set\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mneg_mode\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"easy\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-13-e2e229420451>\u001b[0m in \u001b[0;36mmain\u001b[1;34m(filename, pos_num, neg_num, model, combine_num, feature_name, neg_mode)\u001b[0m\n\u001b[0;32m      7\u001b[0m         \u001b[0mpos_num\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      8\u001b[0m         \u001b[0mneg_num\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 9\u001b[1;33m     \u001b[0mg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_graph_from_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     10\u001b[0m     \u001b[0mnum_edges\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumber_of_edges\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     11\u001b[0m     \u001b[0mpos_num\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_edges\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mpos_num\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-10-7ec40c9bb4de>\u001b[0m in \u001b[0;36mcreate_graph_from_file\u001b[1;34m(filename)\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcreate_graph_from_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m     \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"----------------build graph--------------------\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m     \u001b[0mf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"rb\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      4\u001b[0m     \u001b[0mg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_edgelist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data\\\\facebook_combined.txt'"
     ]
    }
   ],
   "source": [
    "#Run this line to genrate feature Set\n",
    "main(filename=fn,model=\"combined\",combine_num=cn, feature_name=feature_set, neg_mode=\"easy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "r=np.loadtxt(open(r\"data\\features_combined_\"+str(cn)+\".csv\", \"rb\"), delimiter=\",\", skiprows=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "l,b=r.shape;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.shuffle(r);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26468"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_l=int(0.75*l)\n",
    "X_train=r[0:train_l,0:b-1]\n",
    "Y_train=r[0:train_l,b-1]\n",
    "test = r[train_l:l,:]\n",
    "X_test=test[:,0:b-1]\n",
    "Y_test=test[:,b-1]\n",
    "X_train = normalize(X_train, axis=0, norm='max')\n",
    "X_test = normalize(X_test, axis=0, norm='max')\n",
    "scaler = StandardScaler()  \n",
    "scaler.fit(X_train)  \n",
    "X_train = scaler.transform(X_train)  \n",
    "X_test = scaler.transform(X_test)\n",
    "len(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mySvm(training, training_labels, testing, testing_labels):\n",
    "    #Support Vector Machine\n",
    "    start = datetime.datetime.now()\n",
    "    clf = svm.SVC()\n",
    "    clf.fit(training, training_labels)\n",
    "    print (\"+++++++++ Finishing training the SVM classifier ++++++++++++\")\n",
    "    result = clf.predict(testing)\n",
    "\n",
    "    print (\"SVM accuracy:\", accuracy_score(testing_labels, result))\n",
    "    #keep the time\n",
    "    finish = datetime.datetime.now()\n",
    "    print ((finish-start).seconds)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+++++++++ Finishing training the SVM classifier ++++++++++++\n",
      "SVM accuracy: 0.6791112454655381\n",
      "18\n"
     ]
    }
   ],
   "source": [
    "#Run this to for SVM classification\n",
    "svmres=mySvm(X_train,Y_train,X_test,Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logistic(training, training_labels, testing, testing_labels):\n",
    "    clf = LogisticRegression(random_state=0, solver='lbfgs',multi_class='ovr')\n",
    "    start = datetime.datetime.now()\n",
    "    clf.fit(training, training_labels)\n",
    "    result=clf.predict(testing)\n",
    "    print (\"+++++++++ Finishing training the Linear classifier ++++++++++++\")\n",
    "    print (\"Linear accuracy:\", accuracy_score(testing_labels, result))\n",
    "    #keep the time\n",
    "    finish = datetime.datetime.now()\n",
    "    print ((finish-start).seconds)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+++++++++ Finishing training the Linear classifier ++++++++++++\n",
      "Linear accuracy: 0.6653567110036276\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "#Run this for Logistic Regression\n",
    "logres=logistic(X_train,Y_train,X_test,Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ANN(training, training_labels, testing, testing_labels):\n",
    "    clf = MLPClassifier(solver='adam', alpha=1e-5,hidden_layer_sizes=(12,9), random_state=1, max_iter=1000)\n",
    "    start = datetime.datetime.now()\n",
    "    clf.fit(training, training_labels)\n",
    "    print (\"+++++++++ Finishing training the ANN classifier ++++++++++++\")\n",
    "    result = clf.predict(testing)\n",
    "\n",
    "    print (\"ANN accuracy:\", accuracy_score(testing_labels, result))\n",
    "    #keep the time\n",
    "    finish = datetime.datetime.now()\n",
    "    print ((finish-start).seconds)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+++++++++ Finishing training the ANN classifier ++++++++++++\n",
      "ANN accuracy: 0.6806227327690447\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "# Run this for ANN classification\n",
    "annres=ANN(X_train,Y_train,X_test,Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
